import asyncio
import json

import aiohttp
import tiktoken
from loguru import logger

from app.config import PERPLEXITY_API_KEY
from llm_generation.config import DEEPSEEK_MAX_TOKEN_LENGTH
from llm_generation.models.base import BaseModel
from llm_generation.models.data_structure import StreamingDelta


class Perplexity(BaseModel):
    """Client for the Perplexity chat-completions HTTP API.

    Posts a token-truncated user prompt (plus optional prior conversation) to
    ``https://api.perplexity.ai/chat/completions`` and returns the generated
    text, read either from a single JSON response or assembled from a
    server-sent-event (SSE) stream.
    """

    def __init__(self, model_name: str = "sonar-reasoning-pro"):
        """Create a client whose requests use ``model_name`` as the ``model`` field."""
        super().__init__(model_name)

    async def generate_response(
        self, user_prompt: str, conversation: list = None, **kwargs
    ) -> str:
        """Send ``user_prompt`` to Perplexity and return the generated text.

        Args:
            user_prompt: Prompt text; truncated to ``DEEPSEEK_MAX_TOKEN_LENGTH``
                tokens before sending.
            conversation: Optional earlier messages (``{"role", "content"}``
                dicts) placed before the user prompt. Not mutated.
            **kwargs: Extra fields merged into the request payload. They are
                applied last, so they can override the defaults set here. A
                truthy ``stream`` switches to SSE handling.

        Returns:
            The message content of the first choice, or, when streaming, the
            concatenation of all streamed ``delta.content`` pieces.

        Raises:
            aiohttp.ClientResponseError: If the API answers with an HTTP error
                status (>= 400).
        """
        # NOTE(review): the gpt-4o tokenizer and the DeepSeek token limit are
        # used here only as an approximation for Perplexity models; the
        # original author marked this as temporary.
        tokenizer = tiktoken.encoding_for_model("gpt-4o")
        prompt_tokens = tokenizer.encode(user_prompt)
        if len(prompt_tokens) > DEEPSEEK_MAX_TOKEN_LENGTH:
            user_prompt = tokenizer.decode(prompt_tokens[:DEEPSEEK_MAX_TOKEN_LENGTH])

        headers = {
            "Authorization": f"Bearer {PERPLEXITY_API_KEY}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model_name,
            "messages": [
                *(conversation or []),
                {"role": "user", "content": user_prompt},
            ],
            "search_recency_filter": "day",
            **kwargs,
        }

        # Only a truthy ``stream`` requests SSE output. Testing mere presence
        # of the key would make ``stream=False`` parse a plain JSON body as a
        # stream and return an empty string.
        streaming = bool(kwargs.get("stream"))

        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://api.perplexity.ai/chat/completions",
                headers=headers,
                json=payload,
            ) as response:
                # Surface auth/quota/validation failures instead of returning
                # "" (streaming) or failing later with an opaque KeyError.
                response.raise_for_status()

                if streaming:
                    return await self._collect_stream(response)

                response_data = await response.json()
                return response_data["choices"][0]["message"]["content"]

    async def _collect_stream(self, response) -> str:
        """Consume an SSE response and return the concatenated delta content.

        For every parsed delta, ``self.streaming_callback`` (if set) is called
        with the text accumulated *before* that delta and the delta itself;
        coroutine functions are awaited, plain callables are called directly.
        """
        if self.streaming_callback is None:
            logger.warning("No streaming callback is set, skipping callback function")

        prefix = "data: "
        content = ""
        async for raw_line in response.content:
            line = raw_line.decode("utf-8").strip()
            if not line.startswith(prefix):
                # Blank separators, comments and non-data SSE fields.
                continue

            data = line[len(prefix):]
            if data == "[DONE]":
                break

            # Keep the try minimal so that errors raised by the callback are
            # not mistaken for malformed chunks and silently dropped.
            try:
                chunk = json.loads(data)
            except json.JSONDecodeError:
                logger.debug("Skipping non-JSON stream chunk: {}", data)
                continue

            choices = chunk.get("choices") if isinstance(chunk, dict) else None
            if not choices:
                continue
            delta_dict = choices[0].get("delta")
            if delta_dict is None:
                continue
            delta = StreamingDelta(**delta_dict)

            if self.streaming_callback:
                if asyncio.iscoroutinefunction(self.streaming_callback):
                    await self.streaming_callback(content, delta)
                else:
                    self.streaming_callback(content, delta)

            if delta.content:
                content += delta.content

        return content


async def main():
    """Placeholder async entry point; performs no work and returns ``None``."""
    return None


# When executed as a script (not imported), run the async entry point to
# completion on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
